{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0",
   "metadata": {},
   "source": [
    "This example requires the `lightly` package with its `timm` extra to be installed:\n",
    "\n",
    "```\n",
    "pip install \"lightly[timm]\"\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the %pip magic instead of !pip so that the package is installed into the\n",
    "# environment of the running kernel rather than whichever pip is first on PATH.\n",
    "%pip install \"lightly[timm]\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Note: The model and training settings do not follow the reference settings\n",
    "# from the paper. The settings are chosen such that the example can easily be\n",
    "# run on a small dataset with a single GPU.\n",
    "import pytorch_lightning as pl\n",
    "import torch\n",
    "import torchvision\n",
    "from timm.models.vision_transformer import vit_base_patch32_224\n",
    "from torch import nn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3",
   "metadata": {},
   "outputs": [],
   "source": [
    "from lightly.models import utils\n",
    "from lightly.models.modules import MAEDecoderTIMM, MaskedVisionTransformerTIMM\n",
    "from lightly.transforms import MAETransform"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4",
   "metadata": {},
   "outputs": [],
   "source": [
    "class MAE(pl.LightningModule):\n",
    "    \"\"\"Masked autoencoder combining a timm ViT backbone with a lightly MAE decoder.\n",
    "\n",
    "    Each training step draws a random token mask, encodes the images together\n",
    "    with the kept token indices, lets the decoder predict pixel values at the\n",
    "    masked positions and compares them with the real image patches (MSE loss).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "        # Encoder: a timm ViT wrapped by lightly's masked vision transformer.\n",
    "        vit = vit_base_patch32_224()\n",
    "        self.mask_ratio = 0.75\n",
    "        self.patch_size = vit.patch_embed.patch_size[0]\n",
    "        self.backbone = MaskedVisionTransformerTIMM(vit=vit)\n",
    "        self.sequence_length = self.backbone.sequence_length\n",
    "\n",
    "        # Decoder: sized from the ViT's patch embedding and embedding dimension.\n",
    "        decoder_settings = dict(\n",
    "            num_patches=vit.patch_embed.num_patches,\n",
    "            patch_size=self.patch_size,\n",
    "            embed_dim=vit.embed_dim,\n",
    "            decoder_embed_dim=512,\n",
    "            decoder_depth=1,\n",
    "            decoder_num_heads=16,\n",
    "            mlp_ratio=4.0,\n",
    "            proj_drop_rate=0.0,\n",
    "            attn_drop_rate=0.0,\n",
    "        )\n",
    "        self.decoder = MAEDecoderTIMM(**decoder_settings)\n",
    "        self.criterion = nn.MSELoss()\n",
    "\n",
    "    def forward_encoder(self, images, idx_keep=None):\n",
    "        \"\"\"Encode ``images`` with the backbone, forwarding ``idx_keep`` unchanged.\"\"\"\n",
    "        encoded = self.backbone.encode(images=images, idx_keep=idx_keep)\n",
    "        return encoded\n",
    "\n",
    "    def forward_decoder(self, x_encoded, idx_keep, idx_mask):\n",
    "        \"\"\"Predict pixel values for the tokens selected by ``idx_mask``.\n",
    "\n",
    "        The encoded tokens are embedded for the decoder and written at the\n",
    "        ``idx_keep`` positions of a sequence filled with the decoder's mask\n",
    "        token. The whole sequence is decoded and only the ``idx_mask``\n",
    "        positions are handed to the prediction head.\n",
    "        \"\"\"\n",
    "        num_samples = x_encoded.shape[0]\n",
    "\n",
    "        # Assemble the decoder input: mask tokens everywhere, embedded encoder\n",
    "        # output at the kept positions.\n",
    "        embedded = self.decoder.embed(x_encoded)\n",
    "        mask_tokens = utils.repeat_token(\n",
    "            self.decoder.mask_token, (num_samples, self.sequence_length)\n",
    "        )\n",
    "        decoder_input = utils.set_at_index(\n",
    "            mask_tokens, idx_keep, embedded.type_as(mask_tokens)\n",
    "        )\n",
    "\n",
    "        # Run the decoder on the full sequence.\n",
    "        decoded = self.decoder.decode(decoder_input)\n",
    "\n",
    "        # Only the masked tokens are turned into pixel predictions.\n",
    "        masked_only = utils.get_at_index(decoded, idx_mask)\n",
    "        return self.decoder.predict(masked_only)\n",
    "\n",
    "    def training_step(self, batch, batch_idx):\n",
    "        \"\"\"Compute the reconstruction loss on the masked patches of one batch.\"\"\"\n",
    "        # batch[0] holds the views; it contains only a single view.\n",
    "        images = batch[0][0]\n",
    "        mask_shape = (images.shape[0], self.sequence_length)\n",
    "        idx_keep, idx_mask = utils.random_token_mask(\n",
    "            size=mask_shape,\n",
    "            mask_ratio=self.mask_ratio,\n",
    "            device=images.device,\n",
    "        )\n",
    "        predictions = self.forward_decoder(\n",
    "            x_encoded=self.forward_encoder(images=images, idx_keep=idx_keep),\n",
    "            idx_keep=idx_keep,\n",
    "            idx_mask=idx_mask,\n",
    "        )\n",
    "\n",
    "        # The image patches have no class token, so the mask indices are\n",
    "        # shifted by one before selecting the target patches.\n",
    "        all_patches = utils.patchify(images, self.patch_size)\n",
    "        targets = utils.get_at_index(all_patches, idx_mask - 1)\n",
    "\n",
    "        return self.criterion(predictions, targets)\n",
    "\n",
    "    def configure_optimizers(self):\n",
    "        \"\"\"Return an AdamW optimizer over all parameters with lr=1.5e-4.\"\"\"\n",
    "        return torch.optim.AdamW(self.parameters(), lr=1.5e-4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Instantiate the MAE LightningModule; all hyperparameters are set in MAE.__init__.\n",
    "model = MAE()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# MAETransform with its default settings. MAE.training_step reads views[0], i.e.\n",
    "# it expects the transform to return a list holding a single view per image.\n",
    "transform = MAETransform()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The object detection annotations are not needed for this example, so every\n",
    "# target is replaced by a constant placeholder.\n",
    "def target_transform(t):\n",
    "    \"\"\"Discard the annotation ``t`` and return the placeholder label 0.\"\"\"\n",
    "    placeholder_label = 0\n",
    "    return placeholder_label"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# download=True fetches Pascal VOC into datasets/pascal_voc. Images are passed\n",
    "# through the MAE transform; annotations are replaced by target_transform.\n",
    "dataset = torchvision.datasets.VOCDetection(\n",
    "    \"datasets/pascal_voc\",\n",
    "    download=True,\n",
    "    transform=transform,\n",
    "    target_transform=target_transform,\n",
    ")\n",
    "# or create a dataset from a folder containing images or videos\n",
    "# (the same transform must be passed so that batches have the same structure):\n",
    "# from lightly.data import LightlyDataset\n",
    "# dataset = LightlyDataset(\"path/to/folder\", transform=transform)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shuffled batches of 256 samples; drop_last=True discards a final incomplete\n",
    "# batch. Data is loaded by 8 worker processes; lower batch_size or num_workers\n",
    "# if GPU memory or CPU cores are limited.\n",
    "dataloader = torch.utils.data.DataLoader(\n",
    "    dataset,\n",
    "    batch_size=256,\n",
    "    shuffle=True,\n",
    "    drop_last=True,\n",
    "    num_workers=8,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train on a GPU when CUDA is available, otherwise fall back to the CPU.\n",
    "accelerator = \"gpu\" if torch.cuda.is_available() else \"cpu\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train for 10 epochs on a single device of the accelerator type chosen above.\n",
    "trainer = pl.Trainer(\n",
    "    max_epochs=10,\n",
    "    devices=1,\n",
    "    accelerator=accelerator,\n",
    ")\n",
    "trainer.fit(model=model, train_dataloaders=dataloader)"
   ]
  }
 ],
 "metadata": {
  "jupytext": {
   "cell_metadata_filter": "-all",
   "main_language": "python",
   "notebook_metadata_filter": "-all"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
